package mongox

import (
	"context"
	"fmt"
	"gitee.com/zhongguo168a/gocodes/datax"
	"gitee.com/zhongguo168a/gocodes/myx/errorx"
	"go.mongodb.org/mongo-driver/mongo"
)

// NewTransactionRequest returns a TransactionRequest for the database
// registered under dbName. Nothing is looked up or connected here; the
// database and client are resolved lazily when Commit runs.
func NewTransactionRequest(dbName string) (obj *TransactionRequest) {
	return &TransactionRequest{dbName: dbName}
}

// NewTableRequest builds a table Request named tableName that is bound to
// req.sessionCtx. Commit assigns sessionCtx before running the registered
// callbacks, so this is meant to be called from inside a WithCallback callback;
// called elsewhere, the request gets a nil or stale session context.
func (req *TransactionRequest) NewTableRequest(tableName string) *Request {
	return NewRequestByName(tableName).WithContext(req.sessionCtx)
}

// TransactionRequest collects callbacks that Commit runs inside one MongoDB
// transaction against the database registered under dbName.
type TransactionRequest struct {
	// ctx is the caller-supplied context (set via WithContext); getContext
	// falls back to context.Background() when it is nil.
	ctx        context.Context
	// sessionCtx is the session context of the running transaction. Commit
	// assigns it before invoking callbacks so NewTableRequest can bind to it.
	sessionCtx mongo.SessionContext
	// db caches the Database resolved from dbName by getDatabase.
	db         *Database
	// dbName is the name passed to NewTransactionRequest.
	dbName     string
	// dbsubkey is passed to Database.GetAddr to choose the address used by
	// getClient. No setter is visible in this file, so it stays "" unless set
	// elsewhere in the package.
	dbsubkey   string
	// callbacks are registered with WithCallback and run by Commit in
	// registration order.
	callbacks  []func(req *TransactionRequest) (interface{}, error)
	// results holds the callback return values collected by Commit, in
	// callback registration order.
	results    []interface{}
}

// WithContext records ctx as the context used when Commit starts, runs and
// ends the transaction session. It returns req so calls can be chained.
func (req *TransactionRequest) WithContext(ctx context.Context) *TransactionRequest {
	req.ctx = ctx
	return req
}

// WithCallback appends callback to the list that Commit executes inside the
// transaction. Callbacks run in the order they are registered and receive req,
// so they can create session-bound table requests via NewTableRequest.
// It returns req so calls can be chained.
func (req *TransactionRequest) WithCallback(callback func(*TransactionRequest) (interface{}, error)) *TransactionRequest {
	req.callbacks = append(req.callbacks, callback)
	return req
}

// getDatabase returns the Database for req.dbName. It resolves it through the
// package-level getDatabase on first use and caches the result in req.db.
func (req *TransactionRequest) getDatabase() (*Database, error) {
	if req.db != nil {
		return req.db, nil
	}

	resolved, lookupErr := getDatabase(req.dbName)
	if lookupErr != nil {
		return nil, lookupErr
	}
	req.db = resolved
	return resolved, nil
}

// getClient resolves the *mongo.Client for this request in three steps:
// look up the database, select its address by dbsubkey, then obtain that
// address's connection. Failures from the last two steps are wrapped with the
// subkey / database name for context.
func (req *TransactionRequest) getClient() (*mongo.Client, error) {
	db, dbErr := req.getDatabase()
	if dbErr != nil {
		return nil, dbErr
	}

	addr, addrErr := db.GetAddr(req.dbsubkey)
	if addrErr != nil {
		return nil, errorx.Wrap(addrErr, "get db addr", datax.M{"subkey": req.dbsubkey})
	}

	conn, connErr := addr.GetConnection()
	if connErr != nil {
		return nil, errorx.Wrap(connErr, "get connection", datax.M{"database": req.dbName})
	}
	return conn, nil
}
// getContext returns the context supplied through WithContext, or
// context.Background() when none was supplied.
func (req *TransactionRequest) getContext() context.Context {
	if req.ctx == nil {
		return context.Background()
	}
	return req.ctx
}

// Commit runs every callback registered with WithCallback inside a single
// MongoDB transaction and commits it.
//
// The returned slice holds the values returned by the callbacks, in the order
// the callbacks were registered with WithCallback. If a callback fails, the
// transaction is not committed and Commit returns a nil slice and the error.
//
// Session.WithTransaction may run the callbacks more than once when it retries
// a transient failure, so callbacks should be safe to re-run.
func (req *TransactionRequest) Commit() ([]interface{}, error) {
	client, err := req.getClient()
	if err != nil {
		err = errorx.Wrap(err, fmt.Sprintf("get client"))
		return nil, err
	}

	ctx := req.getContext()

	session, err := client.StartSession()
	if err != nil {
		return nil, err
	}
	defer session.EndSession(ctx)

	var results []interface{}
	_, err = session.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) {
		// Start every attempt with a fresh slice: WithTransaction retries this
		// function on transient errors, and appending to req.results (as before)
		// leaked results from aborted attempts and from earlier Commit calls.
		attempt := make([]interface{}, 0, len(req.callbacks))
		// Expose the session to callbacks so NewTableRequest binds to it.
		req.sessionCtx = sessCtx
		for _, callback := range req.callbacks {
			result, callbackErr := callback(req)
			if callbackErr != nil {
				return nil, callbackErr
			}
			attempt = append(attempt, result)
		}
		results = attempt
		return attempt, nil
	})
	if err != nil {
		return nil, err
	}
	// Keep req.results in sync for any package code that reads it.
	req.results = results
	return results, nil
}
